"""WebDAV connector"""
import logging
import os
from datetime import datetime, timezone
from typing import Any, Optional

from webdav4.client import Client as WebDAVClient

from common.data_source.utils import (
    get_file_ext,
)
from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE, BLOB_STORAGE_SIZE_THRESHOLD
from common.data_source.exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError,
    CredentialExpiredError,
    InsufficientPermissionsError
)
from common.data_source.interfaces import LoadConnector, PollConnector
from common.data_source.models import Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput


class WebDAVConnector(LoadConnector, PollConnector):
    """WebDAV connector for syncing files from WebDAV servers

    Files below ``remote_path`` are listed recursively, filtered by their
    last-modified time, downloaded into memory and yielded as batches of
    ``Document`` objects. ``load_credentials`` must be called before any
    listing, loading, polling or validation method.
    """

    # Format tried first for string timestamps, e.g. "Mon, 01 Jan 2024 10:00:00 GMT"
    _HTTP_DATE_FORMAT = '%a, %d %b %Y %H:%M:%S %Z'

    def __init__(
        self,
        base_url: str,
        remote_path: str = "/",
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        """Initialize WebDAV connector
        
        Args:
            base_url: Base URL of the WebDAV server (e.g., "https://webdav.example.com")
            remote_path: Remote path to sync from (default: "/")
            batch_size: Number of documents per batch
        """
        self.base_url = base_url.rstrip("/")
        if not remote_path:
            remote_path = "/"
        if not remote_path.startswith("/"):
            remote_path = f"/{remote_path}"
        # Drop trailing slashes, but never let a slash-only path such as "//"
        # collapse into an empty string.
        remote_path = remote_path.rstrip("/") or "/"
        self.remote_path = remote_path
        self.batch_size = batch_size
        # Set by load_credentials() once the connection test has succeeded
        self.client: Optional[WebDAVClient] = None
        self._allow_images: bool | None = None
        # Files whose reported size exceeds this many bytes are skipped; None disables the check
        self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD

    def set_allow_images(self, allow_images: bool) -> None:
        """Set whether to process images"""
        logging.info(f"Setting allow_images to {allow_images}.")
        self._allow_images = allow_images

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load credentials and initialize WebDAV client
        
        The client is only stored on the connector after a test request
        against ``remote_path`` has completed without raising, so a failed
        attempt never leaves a half-initialized client behind.
        
        Args:
            credentials: Dictionary containing 'username' and 'password'
        
        Returns:
            None
        
        Raises:
            ConnectorMissingCredentialError: If required credentials are missing
                or the test request to the server fails
        """
        logging.debug(f"Loading credentials for WebDAV server {self.base_url}")

        username = credentials.get("username")
        password = credentials.get("password")
        
        if not username or not password:
            raise ConnectorMissingCredentialError(
                "WebDAV requires 'username' and 'password' credentials"
            )

        try:
            # Initialize WebDAV client
            client = WebDAVClient(
                base_url=self.base_url,
                auth=(username, password)
            )
            
            # Test connection
            client.exists(self.remote_path)
            
        except Exception as e:
            logging.error(f"Failed to connect to WebDAV server: {e}")
            # Chain the original error so the root cause stays in the traceback
            raise ConnectorMissingCredentialError(
                f"Failed to authenticate with WebDAV server: {e}"
            ) from e

        self.client = client
        return None

    @staticmethod
    def _parse_modified_time(modified_time: Any, item_path: str) -> datetime:
        """Convert a listing's 'modified' value into a timezone-aware datetime
        
        Args:
            modified_time: Value of the 'modified' field; a datetime, a string
                or anything else (including None)
            item_path: Path of the item, used only in the warning message
        
        Returns:
            A timezone-aware datetime. Naive values are taken to be UTC.
            Missing, unsupported or unparseable values fall back to the
            current UTC time.
        """
        if isinstance(modified_time, datetime):
            parsed = modified_time
        elif isinstance(modified_time, str) and modified_time:
            try:
                parsed = datetime.strptime(modified_time, WebDAVConnector._HTTP_DATE_FORMAT)
            except (ValueError, TypeError):
                try:
                    parsed = datetime.fromisoformat(modified_time.replace('Z', '+00:00'))
                except (ValueError, TypeError):
                    logging.warning(f"Could not parse modified time for {item_path}: {modified_time}")
                    return datetime.now(timezone.utc)
        else:
            return datetime.now(timezone.utc)

        # Comparing naive and aware datetimes raises TypeError, so every
        # naive result (strptime, or an ISO string without offset) is
        # normalized to UTC here.
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    def _list_files_recursive(
        self, 
        path: str,
        start: datetime,
        end: datetime,
    ) -> list[tuple[str, dict]]:
        """Recursively list all files in the given path
        
        Listing is best-effort: an error while listing a directory or while
        processing a single entry is logged and whatever was collected so far
        is still returned.
        
        Args:
            path: Path to list files from
            start: Start datetime for filtering (exclusive)
            end: End datetime for filtering (inclusive)
            
        Returns:
            List of tuples containing (file_path, file_info)
        
        Raises:
            ConnectorMissingCredentialError: If the client is not initialized
        """
        if self.client is None:
            raise ConnectorMissingCredentialError("WebDAV client not initialized")

        files: list[tuple[str, dict]] = []
        # Compared without surrounding slashes so the directory's own entry is
        # recognized however the server formats it (guards against recursing
        # into the directory being listed).
        current_dir = path.strip('/')
        
        try:
            logging.debug(f"Listing directory: {path}")
            for item in self.client.ls(path, detail=True):
                item_path = item['name']
            
                if item_path.strip('/') == current_dir:
                    continue
                
                logging.debug(f"Found item: {item_path}, type: {item.get('type')}")

                if item.get('type') == 'directory':
                    try:
                        files.extend(self._list_files_recursive(item_path, start, end))
                    except Exception as e:
                        logging.error(f"Error recursing into directory {item_path}: {e}")
                        continue
                else:
                    try:
                        modified = self._parse_modified_time(item.get('modified'), item_path)
                        in_range = start < modified <= end

                        logging.debug(f"File {item_path}: modified={modified}, start={start}, end={end}, include={in_range}")
                        if in_range:
                            files.append((item_path, item))
                        else:
                            logging.debug(f"File {item_path} filtered out by time range")
                    except Exception as e:
                        logging.error(f"Error processing file {item_path}: {e}")
                        continue
                        
        except Exception as e:
            logging.error(f"Error listing directory {path}: {e}")
            
        return files

    def _yield_webdav_documents(
        self,
        start: datetime,
        end: datetime,
    ) -> GenerateDocumentsOutput:
        """Generate documents from WebDAV server
        
        Files larger than ``size_threshold`` and files whose downloaded
        content is empty are skipped. A failure while downloading one file is
        logged and does not stop the remaining files.
        
        Args:
            start: Start datetime for filtering
            end: End datetime for filtering
            
        Yields:
            Batches of documents
        
        Raises:
            ConnectorMissingCredentialError: If the client is not initialized
        """
        # Imported once per call instead of once per downloaded file
        from io import BytesIO

        if self.client is None:
            raise ConnectorMissingCredentialError("WebDAV client not initialized")

        logging.info(f"Searching for files in {self.remote_path} between {start} and {end}")
        files = self._list_files_recursive(self.remote_path, start, end)
        logging.info(f"Found {len(files)} files matching time criteria")
        
        batch: list[Document] = []
        for file_path, file_info in files:
            file_name = os.path.basename(file_path)
            
            size_bytes = file_info.get('size', 0)
            if (
                self.size_threshold is not None
                and isinstance(size_bytes, int)
                and size_bytes > self.size_threshold
            ):
                logging.warning(
                    f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
                )
                continue
            
            try:
                logging.debug(f"Downloading file: {file_path}")
                buffer = BytesIO()
                self.client.download_fileobj(file_path, buffer)
                blob = buffer.getvalue()
                
                if not blob:
                    logging.warning(f"Downloaded content is empty for {file_path}")
                    continue

                modified = self._parse_modified_time(file_info.get('modified'), file_path)

                batch.append(
                    Document(
                        id=f"webdav:{self.base_url}:{file_path}",
                        blob=blob,
                        source=DocumentSource.WEBDAV,
                        semantic_identifier=file_name,
                        extension=get_file_ext(file_name),
                        doc_updated_at=modified,
                        size_bytes=size_bytes if size_bytes else 0
                    )
                )
                
                # ">=" rather than "==" so a non-positive batch_size still
                # flushes instead of accumulating every document in memory.
                if len(batch) >= self.batch_size:
                    yield batch
                    batch = []

            except Exception as e:
                logging.exception(f"Error downloading file {file_path}: {e}")
        
        if batch:
            yield batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Load all documents from WebDAV server
        
        Uses a time window from the Unix epoch up to the current UTC time.
        
        Yields:
            Batches of documents
        """
        logging.debug(f"Loading documents from WebDAV server {self.base_url}")
        return self._yield_webdav_documents(
            start=datetime(1970, 1, 1, tzinfo=timezone.utc),
            end=datetime.now(timezone.utc),
        )

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Poll WebDAV server for updated documents
        
        Args:
            start: Start timestamp (seconds since Unix epoch)
            end: End timestamp (seconds since Unix epoch)
            
        Yields:
            Batches of documents
        
        Raises:
            ConnectorMissingCredentialError: If the client is not initialized
        """
        if self.client is None:
            raise ConnectorMissingCredentialError("WebDAV client not initialized")

        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

        yield from self._yield_webdav_documents(start_datetime, end_datetime)

    def validate_connector_settings(self) -> None:
        """Validate WebDAV connector settings
        
        Raises:
            ConnectorMissingCredentialError: If credentials are not loaded
            ConnectorValidationError: If settings are invalid, the remote path
                does not exist, or the server check fails unexpectedly
            CredentialExpiredError: If the server error mentions 401/unauthorized
            InsufficientPermissionsError: If the server error mentions 403/forbidden
        """
        if self.client is None:
            raise ConnectorMissingCredentialError(
                "WebDAV credentials not loaded."
            )

        if not self.base_url:
            raise ConnectorValidationError(
                "No base URL was provided in connector settings."
            )

        # Only the server call sits inside the try block. The "does not exist"
        # error below must not pass through the handler, which would re-wrap
        # it and could misclassify it when the remote path itself contains
        # text such as "401".
        try:
            path_exists = self.client.exists(self.remote_path)
        except Exception as e:
            error_message = str(e)
            lowered_message = error_message.lower()
            
            if "401" in error_message or "unauthorized" in lowered_message:
                raise CredentialExpiredError(
                    "WebDAV credentials appear invalid or expired."
                ) from e
            
            if "403" in error_message or "forbidden" in lowered_message:
                raise InsufficientPermissionsError(
                    f"Insufficient permissions to access path '{self.remote_path}' on WebDAV server."
                ) from e
            
            if "404" in error_message or "not found" in lowered_message:
                raise ConnectorValidationError(
                    f"Remote path '{self.remote_path}' does not exist on WebDAV server."
                ) from e

            raise ConnectorValidationError(
                f"Unexpected WebDAV client error: {e}"
            ) from e

        if not path_exists:
            raise ConnectorValidationError(
                f"Remote path '{self.remote_path}' does not exist on WebDAV server."
            )


if __name__ == "__main__":
    # Manual smoke test: connection settings come from WEBDAV_* environment
    # variables; URL and path fall back to placeholder defaults.
    demo_connector = WebDAVConnector(
        base_url=os.environ.get("WEBDAV_URL") or "https://webdav.example.com",
        remote_path=os.environ.get("WEBDAV_PATH") or "/",
    )
    env_credentials = {
        "username": os.environ.get("WEBDAV_USERNAME"),
        "password": os.environ.get("WEBDAV_PASSWORD"),
    }

    try:
        demo_connector.load_credentials(env_credentials)
        demo_connector.validate_connector_settings()

        # Pull only the first batch; the sentinel marks "no batch produced".
        nothing_yielded = object()
        first_batch = next(iter(demo_connector.load_from_state()), nothing_yielded)
        if first_batch is not nothing_yielded:
            print("First batch of documents:")
            for document in first_batch:
                print(f"Document ID: {document.id}")
                print(f"Semantic Identifier: {document.semantic_identifier}")
                print(f"Source: {document.source}")
                print(f"Updated At: {document.doc_updated_at}")
                print("---")

    except ConnectorMissingCredentialError as missing_credentials:
        print(f"Error: {missing_credentials}")
    except Exception as unexpected:
        print(f"An unexpected error occurred: {unexpected}")
